"""Attention Interface to handle various attention operators and cache operations.

This module provides an interface between the high-level runtime and cache management system and
the low-level functional attention operators. The interface is designed to provide a homogeneous
object-oriented interface to the high-level runtime via the SequenceInfo class. SequenceInfo is
also responsible for functionalizing information about the sequence and passing it on to the
various attention operators. The AttentionDescriptor is the main interface to the attention
operator and operates on a purely functional paradigm that is compatible with the torch custom op
system.

"""

from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator
from torch._ops import OpOverloadPacket
from torch.fx import Node
from torch.types import Number

from ...._utils import nvtx_range
from ..utils.logger import ad_logger

Constant = Union[int, float, str, None]


class CacheConfig(BaseModel):
    """Cache configuration for attention-related dtypes."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.")
    mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.")
    delta_dtype: Optional[torch.dtype] = Field(
        default=torch.float32, description="Delta cache dtype. Defaults to float32."
    )

    @field_validator("dtype", "mamba_dtype", "delta_dtype", mode="before")
    @classmethod
    def _coerce_dtype(cls, value):
        if value is None or isinstance(value, torch.dtype):
            return value
        if isinstance(value, str):
            dtype = getattr(torch, value, None)
            assert isinstance(dtype, torch.dtype), f"Cannot convert {value!r} to a torch.dtype"
            return dtype
        return value

    def __or__(self, other: "CacheConfig") -> "CacheConfig":
        """Combine two CacheConfig objects field-wise using Python's `or` semantics.

        For each field, selects the first non-None value between `self` and `other`.
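
        Example (illustrative):

            a = CacheConfig(dtype=torch.float16)
            b = CacheConfig(dtype=torch.bfloat16, mamba_dtype=torch.float32)
            merged = a | b  # dtype=torch.float16, mamba_dtype=torch.float32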
        """
        if not isinstance(other, CacheConfig):
            raise NotImplementedError(f"Cannot combine CacheConfig with {type(other)}")
        merged_kwargs = {}
        for field_name in type(self).model_fields.keys():
            merged_kwargs[field_name] = getattr(self, field_name) or getattr(other, field_name)
        return CacheConfig(**merged_kwargs)


class SequenceInfo:
    """An interface to hold information about how the sequence is laid out and stored in cache.

    We assume the sequence + cache is laid out in the following way. Also note that we differentiate
    between arguments that are originally part of the model/graph and arguments that are needed for
    the attention operator when we switch to cached+flattened attention.

    ### ORIGINAL MODEL ARGUMENTS ###################################################################
    - input_ids: [id_0, ..., id_{s_total-1}]
      flattened sequence of shape [1, s_total] or [b, 1]. We use [b, 1] to denote generate-only
      batches.
    - position_ids: [pos_0, ..., pos_{s_total-1}]
      flattened sequence of shape [1, s_total] or [b, 1] indicating the absolute position id for
      every token in the input_ids sequence. We use [b, 1] to denote generate-only batches.

    NOTE: ``input_ids`` and ``position_ids`` are initially expected to be of shape [b, seq_len]
    before we switch to cached+flattened attention.

    ### EXTRA ARGUMENTS PROVIDED TO THE INTERFACE ##################################################
    Those are extra arguments that can be provided to the interface and they are stored as follows:
    - _extra_args: dictionary of extra arguments with currently active values.

    ### CACHE ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############
    - seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i)
      Describes how long each sequence is. For example, input_ids[:s_0] corresponds to sequence 0
      in the batch and input_ids[s_0:s_0 + s_1] corresponds to sequence 1 in the batch.
    - input_pos: [pos_0, ..., pos_{b-1}]
      Corresponds to the total number of tokens that have already been cached for each sequence
      in the batch.
    - cache_loc: [c_0, ..., c_{np-1}] where np is the total number of pages allocated to describe
      all sequences in the batch.
    - pages_per_seq: [ps_0, ps_1, ..., ps_{b-1}] where ps_i is the number of pages allocated for
      sequence i. Note that, for example, cache_loc[ps_0:ps_0 + ps_1] corresponds to the pages
      associated with sequence 1 in the batch.
    - slot_idx: [s_0, s_1, ..., s_{b-1}]
      Corresponds to the slot index of each sequence in the batch.

    ################################################################################################

    Here are a couple of notes to emphasize this notation:

    - The total token capacity allocated to sequence i is given by ps_i * page_size. This is the
      maximum number of tokens that can be cached for that sequence.

    - NOTE: It must hold that pos_i + s_i <= ps_i * page_size for all i in [0, b-1]. Moreover, it is
      the responsibility of the cache manager and/or runtime to ensure sufficient page allocation
      for each sequence.
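
    To make the notation concrete, here is a small illustrative example (values chosen
    arbitrarily) for a batch of b=2 sequences with page_size=4:

    - seq_len:       [5, 2]           # 5 + 2 = 7 tokens in the flattened input_ids
    - input_pos:     [0, 8]           # sequence 1 already has 8 tokens in the cache
    - pages_per_seq: [2, 3]           # ceil((0+5)/4) = 2 pages, ceil((8+2)/4) = 3 pages
    - cache_loc:     [3, 7, 0, 1, 5]  # pages [3, 7] belong to seq 0, pages [0, 1, 5] to seq 1
    - slot_idx:      [0, 1]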

    """

    def __init__(
        self,
        max_seq_len: int = 1,
        max_batch_size: int = 1,
        page_size: int = 0,
        max_num_tokens: Optional[int] = None,
        vocab_size_padded: Optional[int] = None,
        chunk_size: Optional[int] = None,
    ):
        """Initialize the SequenceInfo object.

        Args:
            max_seq_len: corresponds to the maximum sequence length of the input sequence. It
                includes the tokens in the input sequence and the tokens generated by the model.
            max_batch_size: corresponds to the maximum number of sequences (or requests) that the
                model can process.
            page_size: corresponds to the page size of the cache. For an unpaged cache, the page
                size should be set to max_seq_len. Also note that two sequences in a batch cannot
                share a page.
            max_num_tokens: corresponds to the maximum number of tokens that the model can process
                across all sequences in the batch. If a batch is composed of context-only requests
                of input sequence length ISL, then the maximum number of sequences possible in the
                batch is min(max_batch_size, max_num_tokens // ISL). Similarly, if a batch is
                composed of generate-only requests, then the maximum number of sequences possible in
                the batch is min(max_batch_size, max_num_tokens).
            vocab_size_padded: corresponds to the padded vocabulary size of the model.
            chunk_size: corresponds to the chunk size used by mamba prefill kernels to split the
                context into chunks.
        Returns:
            None
        """
        # set up basic attributes
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size
        self.page_size = page_size if page_size > 0 else max_seq_len
        self.vocab_size_padded = vocab_size_padded
        self.chunk_size = chunk_size
        # Chunk size is an input to a custom op, so we need to set a default value if it is not provided.
        if self.chunk_size is None:
            self.chunk_size = 128
        # NOTE (lucaslie): WAR to address issue when using flashinfer attention with
        # (max_batch_size, max_seq_len) input in trtllm runtime.
        # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
        max_seq_len_adjusted = self.max_seq_len + 1

        # if the provided max_num_tokens is less than the max_batch_size * max_seq_len_adjusted,
        # we use the provided max_num_tokens. If max_num_tokens provided is more, we still use
        # max_batch_size * max_seq_len_adjusted since the extra tokens cannot be used.
        self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
        if max_num_tokens is not None and max_num_tokens > 0:
            self.max_num_tokens = min(self.max_num_tokens, max_num_tokens)

        # Num pages cannot be less than max_batch_size.
        self._num_pages = max(
            self.max_batch_size,
            (self.max_num_tokens) // self.page_size  # floored number of pages
            + (self.max_num_tokens / self.max_batch_size % self.page_size > 0)  # check for overflow
            * self.max_batch_size,  # +1 page per sequence if overflow is required
        )
        # sanity check
        assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size"

        # cache_loc requires some special treatment due to block reuse. Note that the constraint for
        # cache_loc with block_reuse is as follows:
        # 0 <= cache_loc < num_pages
        # len(cache_loc) <= max_num_cache_loc_assignments
        max_num_cache_loc_assignments = (
            max_seq_len_adjusted // self.page_size + 1
        ) * self.max_batch_size

        # log parameters
        ad_logger.info(
            f"[SequenceInfo:] {self.max_seq_len=}, {self.max_batch_size=}, {self.page_size=}, "
            f"{self.max_num_tokens=} (inferred), {max_num_tokens=} (provided), {self.num_pages=}, "
            f"{max_num_cache_loc_assignments=}"
        )

        # indicator if extra args are activated that are needed for cached attention backends
        self._is_cached_attn = False

        # TENSOR FIELDS ############################################################################
        self._args_device: Dict[str, torch.Tensor] = {
            # TENSOR FIELDS FOR UNCACHED ATTENTION
            "input_ids": torch.ones(self.max_num_tokens, dtype=torch.int),
            "position_ids": torch.zeros(self.max_num_tokens, dtype=torch.long),
            # TENSOR FIELDS FOR CACHED ATTENTION
            "seq_len": torch.empty(self.max_batch_size, dtype=torch.int),
            "input_pos": torch.empty(self.max_batch_size, dtype=torch.int),
            "cache_loc": torch.empty(max_num_cache_loc_assignments, dtype=torch.int),
            "pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int),
            "slot_idx": torch.empty(self.max_batch_size, dtype=torch.long),
            # OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
            "_gather_idx": torch.empty(self.max_num_tokens, dtype=torch.int),
        }
        self._args_host: Dict[str, List[int]] = {
            k: v.tolist() for k, v in self._args_device.items()
        }
        # NOTE: order of keys is relevant here!
        self._uncached_arg_names = ("input_ids", "position_ids")
        self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx")
        # page_size is the size of attention KV-cache pages.
        # chunk_size is used in mamba prefill kernels to split the context into chunks.
        self._cached_constants = ("page_size", "chunk_size")
        ############################################################################################

        # EXTRA TENSOR FIELDS ######################################################################
        self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
        ############################################################################################

        # call reset once to set a consistent initial state
        self.reset()

    @property
    def device(self) -> torch.device:
        return self._args_device["input_ids"].device

    def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:
        """Shape the tensor for the forward pass based on the current attention mode.

        Args:
            tnsr: The tensor to shape, assumed to have shape [batch_size*seq_len, ...].

        Returns:
            The shaped tensor flattened or unflattened based on the current attention mode.
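
        For example (illustrative): with cached attention active and seq_len = [3, 2], a tensor of
        shape [max_num_tokens, hidden] is truncated and reshaped to [1, 5, hidden]; for a
        generate-only batch with seq_len = [1, 1] it becomes [2, 1, hidden].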
        """
        # check if we are still running uncached attention, in which case we also still operate
        # on unflattened tensors with explicit [batch_size, seq_len, ...] shape;
        # generate-only batches are also formatted like this (i.e. [b, 1])
        if not self._is_cached_attn or self.is_generate:
            bs = len(self.seq_len)
            sl = self.seq_len[0]
        # use [1,total_len] shape to indicate non-generate-only batch for cached attention
        else:
            bs, sl = 1, self.total_num_tokens

        # truncate to total tokens now, reshape, and return
        return tnsr[: self.total_num_tokens].view(bs, sl, *tnsr.shape[1:])

    def _named_args(
        self, include_extra_args: bool = True, include_cached_args: bool = True
    ) -> Dict[str, torch.Tensor]:
        # start with uncached args and shape them along the way
        args = {k: self._shape_for_forward(self._args_device[k]) for k in self._uncached_arg_names}

        # check other args to include
        if include_extra_args:
            args.update(self._extra_args)

        if include_cached_args:
            args.update({k: self._args_device[k] for k in self._cached_arg_names})

        return args

    @property
    def named_args(self) -> Dict[str, torch.Tensor]:
        """Return a dictionary of named arguments.

        These arguments contain all arguments that are managed by this interface and are required
        to run a model's forward pass including all extra arguments.

        Cached arguments are only included if the attention mode is cached to reflect that after
        switching to cached attention, the cached arguments are required for a forward pass.
        """
        return self._named_args(include_extra_args=True, include_cached_args=self._is_cached_attn)

    @property
    def named_standard_args(self) -> Dict[str, torch.Tensor]:
        """Return a dictionary of named standard arguments.

        We define standard arguments as the arguments that are part of the model's forward function
        by default (i.e., without the extra arguments).

        Just like ``named_args``, this property includes cached attention arguments if the
        attention mode is cached.
        """
        return self._named_args(include_extra_args=False, include_cached_args=self._is_cached_attn)

    @property
    def args(self) -> Tuple[torch.Tensor, ...]:
        """Return a tuple of arguments."""
        return tuple(self.named_args.values())

    @property
    def args_for_prepare_metadata(self) -> Tuple[str, ...]:
        """Return a tuple of node/tensor arguments for the prepare_metadata op.

        The ``prepare_metadata`` interface expects the following arguments:

        1. ``args_for_prepare_metadata`` as nodes, i.e., as input-dependent tensors.
        2. ``const_args_for_prepare_metadata`` as constants that can be passed directly as args
           to the corresponding ``prepare_metadata`` node/op.

        This interface handles the tensor/node arguments part and can be used by compiler passes
        like ``insert_cached_attention`` to extract the constant arguments and add them to the
        ``prepare_metadata`` node/op.
        """
        # NOTE: for now we do _not_ include input_ids since we are not guaranteed that input_ids
        # is part of the graph, e.g., in situations where the graph is a submodule of the overall
        # model. In such instances, the graph usually sees inputs_embeds. However, we assume for
        # now that position_ids is always part of the graph.
        return ("position_ids",) + self._cached_arg_names

    @property
    def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]:
        """Return a tuple of extra (const, non-tensor) arguments for the prepare_metadata op.

        The ``prepare_metadata`` interface expects the following arguments:

        1. ``args_for_prepare_metadata`` as nodes, i.e., as input-dependent tensors.
        2. ``const_args_for_prepare_metadata`` as constants that can be passed directly as args
           to the corresponding ``prepare_metadata`` node/op.

        This interface handles the constant arguments part and can be used by compiler passes like
        ``insert_cached_attention`` to extract the constant arguments and add them to the
        ``prepare_metadata`` node/op.
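
        For example, a compiler pass could assemble the full argument list for the
        ``prepare_metadata`` node roughly as follows (illustrative; ``node_lookup`` is a
        hypothetical mapping from argument names to graph nodes):

            args = tuple(node_lookup[name] for name in seq_info.args_for_prepare_metadata)
            args += seq_info.const_args_for_prepare_metadata  # (page_size, chunk_size)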
        """
        return tuple(getattr(self, k) for k in self._cached_constants)

    @property
    def seq_len(self) -> List[int]:
        return self._args_host["seq_len"].copy()

    @property
    def input_pos(self) -> List[int]:
        return self._args_host["input_pos"].copy()

    @property
    def cache_loc(self) -> List[int]:
        return self._args_host["cache_loc"].copy()

    @property
    def pages_per_seq(self) -> List[int]:
        return self._args_host["pages_per_seq"].copy()

    @property
    def num_sequences(self) -> int:
        return len(self.seq_len)

    @property
    def total_num_tokens(self) -> int:
        return sum(self.seq_len)

    @property
    def is_generate(self) -> bool:
        return all(sl == 1 for sl in self.seq_len)

    @property
    def num_pages(self) -> int:
        return self._num_pages

    @num_pages.setter
    def num_pages(self, value):
        self._num_pages = value
        # update the cache_loc tensor
        if self._args_device["cache_loc"].numel() < value:
            self._args_device["cache_loc"].resize_(value)

    @property
    def is_paged(self) -> bool:
        return self.page_size < self.max_seq_len

    @property
    def page_assignments(self) -> List[List[int]]:
        """Return the page assignments for each sequence."""
        return self._get_page_assignments(self.cache_loc, self.pages_per_seq)

    @staticmethod
    def _get_page_assignments(
        cache_locations: List[int], pages_per_sequence: List[int]
    ) -> List[List[int]]:
        """Get nested page assignments from cache locations and pages per sequence as list of lists.

        Args:
            cache_locations: A flat list of cache locations for each sequence ordered by sequence.
            pages_per_sequence: A list of number of pages per sequence.

        Returns:
            A list of page assignments for each sequence ordered by sequence.
            For example:
                cache_locations: [0, 4, 2]
                pages_per_sequence: [2, 1]
                --> returns [[0, 4], [2]]
        """
        return [
            c_loc_one_seq.tolist()
            for c_loc_one_seq in torch.split(torch.tensor(cache_locations), pages_per_sequence)
        ]

    @staticmethod
    def _get_cache_locations_and_pages_per_sequence(
        page_assignments: List[List[int]],
    ) -> Tuple[List[int], List[int]]:
        """Get cache locations and pages per sequence from nested page assignments (lists of lists).

        Args:
            page_assignments: A list of page assignments for each sequence ordered by sequence.
        Returns:
            A tuple of:
                cache_locations: A flat list of cache locations for each sequence ordered by sequence.
                pages_per_sequence: A list of number of pages per sequence.

        Example:
            page_assignments: [[0, 4], [2]]
            --> returns ([0, 4, 2], [2, 1])

        """
        cache_loc_flat = [p_idx for pages in page_assignments for p_idx in pages]
        pages_per_seq = [len(p) for p in page_assignments]
        return cache_loc_flat, pages_per_seq

    @classmethod
    def _get_sanitized_seq_len(
        cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
    ) -> torch.Tensor:
        """Sanitize sequence lengths.

        We want to cover the following scenarios with this function:

        1. Pre-fill:
            input_ids: [1, s_total, ...]
            seq_len: [s_0, s_1, ..., s_{b-1}, 0, 0, ..., 0]
            ---> returns [s_0, s_1, ..., s_{b-1}]
        2. Decode:
            input_ids: [b, 1, ...]
            seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
                     |---- b ----|--- (max_batch_size - b) ---|
            --> returns [1,] * b
        3. Decode in Cudagraph:
            input_ids: [b_cudagraph, 1, ...]
            seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
                     |---- b ----|--- (max_batch_size - b) ---|

            --> returns [1,] * b_cudagraph
            Here b <= b_cudagraph. We want to make sure that the seq_len is one-padded to
            b_cudagraph.

            # TODO: I could see one possible issue with this approach in the future.
            # If we have b < b_cudagraph we now one-pad. However, we don't pad the cache location
            # information. What could happen is that for the padded sequences the cache location
            # tensors point to allocated pages. This could lead to a situation where we write into
            # allocated cache pages, polluting the cache of other sequences. This would not be an
            # issue if we wrote the dummy sequences into unallocated cache pages... One fix could
            # be to pad not only the seq len but also the cache locations by just repeating the
            # last valid cache location in the batch. This would ensure that the dummy sequences
            # just repeat valid computation...
        """
        _, s = input_or_position_ids.shape[:2]
        num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len)
        if s > 1:
            return seq_len[:num_seq].detach().clone()
        else:
            return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)

    @staticmethod
    def _get_sanitized_num_sequences(
        input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
    ) -> int:
        """Get number of sequences.

        We make sure that this function is compatible with both torch graph capture and cudagraph.
        Both can be a bit temperamental when trying to extract the number of sequences from a
        tensor sized by max_batch_size or max_batch_size*max_seq_len.
        """
        b, s = input_or_position_ids.shape[:2]
        if s > 1:
            num_seq = torch.sum(seq_len > 0)
            assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
        else:
            num_seq = b
        return num_seq

    def switch_to_cached_attn_inputs(self) -> List[str]:
        """Switch to inputs for cached+flattened attention operators.

        Returns:
            List[str]: List of new argument names that are now activated.

        This function will change the inputs provided by the interface from the arguments expected
        by regular attention in PyTorch (SDPA-style) to the arguments needed once we use attention
        operators with cache support and flattened sequences.

        NOTE: The graph inference optimizer is responsible for ensuring that the new inputs are
        correctly reflected in the graph after this function is called.
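
        With the current set of cached argument names (see ``_cached_arg_names``), this returns:

            ["seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx"]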
        """
        assert not self._is_cached_attn, "Cached+flattened attention already activated"
        self._is_cached_attn = True
        return list(self._cached_arg_names)

    def to(self, *args, **kwargs) -> None:
        def _move_dict(d: Dict[str, torch.Tensor]) -> None:
            for k, v in d.items():
                if v is not None:
                    d[k] = v.to(*args, **kwargs)

        _move_dict(self._args_device)
        _move_dict(self._extra_args)

    def set_example_sequence(
        self,
        input_ids: Optional[Sequence[Sequence[int]]] = None,
        position_ids: Optional[Sequence[Sequence[int]]] = None,
        **extra_args,
    ) -> None:
        """Set an example sequence useful for testing and export purposes without cache history."""
        # use a best guess default for input_ids if not provided
        if input_ids is None:
            bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len)
            input_ids = torch.ones(bs, seq_len, dtype=torch.int).tolist()

        # figure out page assignments
        pages_per_seq = [
            len(ids_one_seq) // self.page_size + (len(ids_one_seq) % self.page_size > 0)
            for ids_one_seq in input_ids
        ]
        cache_loc = list(range(sum(pages_per_seq)))
        page_assignments = self._get_page_assignments(cache_loc, pages_per_seq)

        # vanilla slot indices
        slot_idx = list(range(len(input_ids)))

        self.nest_sequences(
            input_ids,
            position_ids,  # will be auto-inferred if None
            input_pos=0,  # no cache history
            page_assignments=page_assignments,  # vanilla page assignments
            slot_idx=slot_idx,  # vanilla slot indices
            **extra_args,
        )

    def set_max_num_tokens_sample(self) -> None:
        """Set an example sequence with max_num_tokens."""
        # TODO (lucaslie): understand what this implies for extra arguments
        seq_len = self.max_num_tokens // self.max_batch_size
        input_ids = torch.ones(self.max_batch_size, seq_len, dtype=torch.int).tolist()
        self.set_example_sequence(input_ids)

    def set_generate_only_batch(self) -> None:
        """Set an example sequence for generate-only batch."""
        self.set_example_sequence([[1]] * self.max_batch_size)

    def reset(self) -> None:
        """Reset the sequence information.

        After reset the sequence information should correspond to a "generate-only" batch of
        sequences (b, s==1) without cache history.
        """
        self.set_generate_only_batch()

    @staticmethod
    def _flatten(nested_seqs: Sequence[Sequence[int]]) -> List[int]:
        return [
            val
            for lst in nested_seqs
            for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
        ]

    def _store_arg(
        self,
        name: str,
        tnsr_like: List[Number],
        reset_val: Optional[Number] = None,
    ) -> None:
        """Store the argument on the host and copy to the device in a non-blocking fashion.

        Args:
            name: Name of the argument to store.
            tnsr_like: List of values to store.
            reset_val: Value to reset/fill the full tensor on the device to before writing to it.
        """
        with nvtx_range(f"ad_store_seq_info_arg_{name}"):
            tnsr_device = self._args_device[name]

            # store list object on the host
            self._args_host[name] = tnsr_like.copy()

            # pin the memory on the host
            tnsr_host = torch.tensor(tnsr_like, dtype=tnsr_device.dtype, pin_memory=True)

            # check for available space
            assert tnsr_device.numel() >= tnsr_host.numel(), (
                f"device tensor {name} is too small, available: {tnsr_device.numel()}, "
                f"required: {tnsr_host.numel()}"
            )

            # reset/copy to the device in a non-blocking fashion
            if reset_val is not None:
                tnsr_device.fill_(reset_val)
            tnsr_device[: len(tnsr_like)].copy_(tnsr_host, non_blocking=True)

    def _store_extra_arg(
        self, name: str, tnsr_like: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]]
    ) -> None:
        with nvtx_range(f"ad_store_extra_arg_{name}"):
            if tnsr_like is not None:
                if not isinstance(tnsr_like, torch.Tensor):
                    if len(tnsr_like) > 1:
                        tnsr_like = torch.cat(tnsr_like)
                    else:
                        tnsr_like = tnsr_like[0]
                self._extra_args[name] = tnsr_like.to(self.device, non_blocking=True)
            else:
                self._extra_args[name] = None

    @nvtx_range("ad_get_unique_value")
    def _get_unique_value(self, occupied: Set[int], max_val: int) -> int:
        """Get un unoccupied value from the range indicated by max_val."""
        # Return the smallest free value; fall back to 0 if none
        for candidate in range(max_val):
            if candidate not in occupied:
                return candidate
        return 0

    @nvtx_range("ad_nest_sequences")
    def nest_sequences(
        self,
        input_ids: Sequence[Sequence[int]],
        position_ids: Optional[Sequence[Sequence[int]]] = None,
        input_pos: Optional[Union[Sequence[int], int]] = None,
        page_assignments: Optional[Sequence[Sequence[int]]] = None,
        slot_idx: Optional[Sequence[int]] = None,
        **extra_args: Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]],
    ) -> None:
        """Create and store sequence information for the next forward pass.

        Args:
            input_ids: List of sequences of input_ids.
            position_ids: List of sequences of position_ids for each token.
            input_pos: Absolute starting position in the cache for each sequence.
            page_assignments: List of sequences of page assignments for each sequence.
            slot_idx: List of slot indices for each sequence.
            extra_args: Extra arguments to be stored in the interface.

        This i/f will ensure that all sequence info args are updated accordingly. Reset values are
        chosen as "neutral" values so that for cases like rounding up batch sizes for cudagraph we
        only write to unused buffers/caches.
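
        Example (illustrative): two context requests of lengths 3 and 2 without cache history,
        each assigned a single page and a vanilla slot index:

            seq_info.nest_sequences(
                input_ids=[[1, 2, 3], [4, 5]],
                input_pos=0,
                page_assignments=[[0], [1]],
                slot_idx=[0, 1],
            )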
        """
        ### UPDATE METADATA ########################################################################
        # update metadata first since it's useful for other updates to have up-to-date information

        # set new sequence lengths --> resetting the remaining entries to zero is important to help
        # us discern the actual number of sequences in the batch.
        self._store_arg("seq_len", [len(ids) for ids in input_ids], reset_val=0)

        # check for updated input_pos (i.e. cache start position)
        if input_pos is not None:
            self._store_arg(
                "input_pos",
                [input_pos] * self.num_sequences if isinstance(input_pos, int) else input_pos,
                reset_val=0,
            )

        # check for updated page_assignments
        if page_assignments is not None:
            cache_loc, pages_per_seq = self._get_cache_locations_and_pages_per_sequence(
                page_assignments
            )
            free_cache_loc = self._get_unique_value(set(cache_loc), self.num_pages)
            self._store_arg("cache_loc", cache_loc, reset_val=free_cache_loc)
            self._store_arg("pages_per_seq", pages_per_seq, reset_val=1)

        # check for updated slot_idx
        if slot_idx is not None:
            free_slot_idx = self._get_unique_value(set(slot_idx), self.max_batch_size)
            self._store_arg("slot_idx", slot_idx, reset_val=free_slot_idx)

        ### UPDATE MAIN INPUTS #####################################################################
        # set new input_ids and make sure to flatten it
        self._store_arg("input_ids", self._flatten(input_ids))

        # set new position_ids and make sure to flatten it
        if position_ids is None:
            position_ids = [
                [num for num in range(in_pos, in_pos + seq_len)]
                for in_pos, seq_len in zip(self.input_pos, self.seq_len)
            ]
        self._store_arg("position_ids", self._flatten(position_ids))

        ### UPDATE EXTRA INPUTS ####################################################################
        self._extra_args = {}
        for key, value in extra_args.items():
            self._store_extra_arg(key, value)

    @nvtx_range("ad_rescatter_input_ids")
    def rescatter_input_ids(
        self, ungathered_input_ids: torch.Tensor, gather_idx: List[int], scatter_ref: int
    ):
        """Re-scatter the provided ungathered input ids into the input_ids tensor.

        Args:
            ungathered_input_ids: The input ids on the device from which to gather.
            gather_idx: The list of indices to gather from the ungathered_input_ids.
            scatter_ref: The reference index to scatter to in input_ids via masked scatter.

        Returns:
            None

        This function will assume that we are in a generate-only batch.
        """
        # store the new gather indices
        self._store_arg("_gather_idx", gather_idx)

        # gather the provided input ids in a streaming fashion
        gather_ids_device = self._args_device["_gather_idx"][: len(gather_idx)]
        packed_input_ids = ungathered_input_ids[gather_ids_device]

        # re-scatter the provided input ids into the input_ids tensor
        input_ids_device = self._args_device["input_ids"]
        input_ids_device.masked_scatter_(input_ids_device == scatter_ref, packed_input_ids)

    @nvtx_range("ad_unnest_sequences")
    def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
        t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
        return list(torch.split(t_squeezed, self.seq_len))


class MHACallable(Protocol):
    def __call__(
        self,
        *qkv_metadata_and_caches: Union[torch.Tensor, Constant],
    ) -> torch.Tensor: ...


class PrepareMetadataCallable(Protocol):
    def __call__(
        self,
        position_ids: torch.Tensor,
        seq_len: torch.Tensor,
        input_pos: torch.Tensor,
        cache_loc: torch.Tensor,
        pages_per_seq: torch.Tensor,
        slot_idx: torch.Tensor,
        page_size: int,
    ) -> List[torch.Tensor]: ...


class GetCacheCallable(Protocol):
    def __call__(self, sequence_info: SequenceInfo) -> torch.Tensor: ...


class GetBufferCallable(GetCacheCallable):
    pass


CacheInitializerDict = Dict[str, GetCacheCallable]
BufferInitializerDict = Dict[str, GetBufferCallable]
AttentionLayout = Literal["bsnd", "bnsd"]


class AttentionDescriptor(ABC):
    """An interface to define a functional attention operator.

    The main logic is contained in the actual attention op as well as the prepare_metadata op. The
    prepare_metadata op is responsible for converting the standardized sequence info into metadata
    specific to the attention op.
    """

    @classmethod
    @abstractmethod
    def is_paged(cls) -> bool:
        """Return if the attention op is paged or not."""

    @classmethod
    @abstractmethod
    def get_attention_layout(cls) -> AttentionLayout:
        """Get the attention layout expected by the source op and the cached attention op."""

    @classmethod
    @abstractmethod
    def get_num_qkv_args(cls) -> int:
        """Get the number of qkv arguments expected by the source op."""

    @classmethod
    @abstractmethod
    def get_source_attention_op(cls) -> OpOverloadPacket:
        """Get the source attention op that we target for replacement."""

    @classmethod
    @abstractmethod
    def get_cached_attention_op(cls) -> MHACallable:
        """Get the cached attention op .

        The attention_op should follow the below signature:

        ```
        def attention_op(
            *qkv,       # list of tensors corresponding to Q, K, V as in source attention op
            *metadata,  # global info about the sequences as returned by the prepare_metadata op
            *caches,    # contains layer-specific caches per provided cache initializers
            *buffers,   # global buffers used by the attention op as provided by buffer initializers
            *constants, # basic arguments (int, float, str, None) added as CONSTANTS in the graph
        ) -> torch.Tensor: ...
        ```

        **Note that the attention op should be a valid torch custom op, which comes with
        restrictions on the supported types in the signature.**

        **Note that the `qkv` tuple should be consistent across both the cached attention
        op and the source attention op that it is replacing.**

        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
        """Get the prepare_metadata op.

        The prepare_metadata op should follow the below signature:

        ```
        def prepare_metadata(
            position_ids: torch.Tensor,
            seq_len: torch.Tensor,
            input_pos: torch.Tensor,
            cache_loc: torch.Tensor,
            pages_per_seq: torch.Tensor,
            slot_idx: torch.Tensor,
            page_size: int,
        ) -> List[torch.Tensor]: ...
        ```
        The metadata should contain all necessary global information required for the underlying
        attention op to process the input sequence and the returned list of tensors will be passed
        on to each invocation of the attention op in the graph.

        prepare_metadata is called once at the beginning of the forward pass.

        **Note that the prepare_metadata op should be a valid torch custom op, which comes with
        restrictions on the supported types in the signature.**
        """

    @classmethod
    @abstractmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: CacheConfig
    ) -> CacheInitializerDict:
        """Provide a dictionary of function pointers that can be used to initialize the caches.

        The key corresponds to the argument name used in the attention op signature. The key does
        not need to be unique across multiple attention nodes in the graph: the key used to
        describe the cache in the graph will be patched with the attention node index to ensure
        uniqueness.

        ``get_cache_initializers`` will be called *once* during cache initialization and before
        the initial forward pass for each attention op detected in the graph. The caches will be
        managed by the global CacheManager and passed back to the attention op during the forward
        pass.

        If the cache initializer requires information about the attention op, it can retrieve
        the necessary information from the source attention node and cache config.
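
        Illustrative sketch (shape factors such as ``num_heads``/``head_dim`` are assumptions and
        would be derived from ``source_attn_node`` in practice):

        ```
        def _init_kv_cache(si: SequenceInfo) -> torch.Tensor:
            # num_heads / head_dim would be derived from source_attn_node in practice
            return torch.empty(
                si.num_pages, si.page_size, num_heads, head_dim,
                dtype=cache_config.dtype or torch.float16,
                device=si.device,
            )

        return {"kv_cache": _init_kv_cache}
        ```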
        """

    @classmethod
    def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
        """Provide a dictionary of function pointers that can be used to initialize buffers.

        The key corresponds to the buffer name used in the graph module and will **not**
        be patched unlike a cache key. Hence, it is a **global** key that is shared across all
        attention ops in the model much like a regular buffer in an nn.Module. That means if this
        i/f is called for multiple attention ops, the same buffer will be shared across all of them
        if this function provides the same key multiple times.

        Buffers are initialized *once* after model initialization and before the initial forward
        pass for each attention op detected in the graph. The buffer will be managed by the global
        CacheManager and passed back to the attention op during the forward pass.

        If the buffer initializer requires information about the attention op, it can retrieve
        the necessary information from the source attention node.
        """

    @classmethod
    @abstractmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
        """Provide a list of constant arguments to be passed to the attention op.

        The constant arguments are passed to the attention op as additional arguments after the
        caches and buffers. The constants are expected to be of type int, float, str, or None.
        """


class AttentionRegistry:
    """A simple registry to look up different attention implementations."""

    _attention_registry: Dict[str, Type["AttentionDescriptor"]] = {}

    @classmethod
    def register(
        cls, kernel_source: str
    ) -> Callable[[Type["AttentionDescriptor"]], Type["AttentionDescriptor"]]:
        def decorator(attention_cls: Type["AttentionDescriptor"]):
            assert kernel_source not in cls._attention_registry, (
                f"Attention source {kernel_source} already registered."
            )
            cls._attention_registry[kernel_source] = attention_cls
            return attention_cls

        return decorator

    @classmethod
    def get(cls, kernel_source: str) -> Type["AttentionDescriptor"]:
        assert cls.has(kernel_source), f"Attention source {kernel_source} not registered."
        return cls._attention_registry[kernel_source]

    @classmethod
    def has(cls, kernel_source: str) -> bool:
        return kernel_source in cls._attention_registry
